import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt

from emodel_generalisation.mcmc import load_chains, plot_cost, plot_cost_convergence
from emodel_generalisation.mcmc import plot_corner


if __name__ == "__main__":
    # Load the chains listed in run_df.csv and write diagnostic figures
    # (cost convergence, most-frequent highest-scoring feature, costs, corner
    # plots) into ./figures.
    Path("figures").mkdir(exist_ok=True)
    run_df = pd.read_csv("run_df.csv")
    df = load_chains(run_df)
    print(df)
    print(df["cost"])

    # Clip costs above 10 to 10. Every plot below uses the clipped values.
    df.loc[df.cost > 10, "cost"] = 10
    plot_cost_convergence(df, filename="figures/cost_convergence.png")

    # Samples with cost below this threshold are kept for the remaining plots.
    split = 2.0

    # For each retained sample, take the feature with the largest score, then
    # count how often each feature is that maximum.
    max_feat = df[df.cost < split]["scores"].idxmax(axis=1).value_counts(ascending=True)

    plt.figure(figsize=(7, 9))
    max_feat.plot.barh(ax=plt.gca())
    plt.xscale("log")
    plt.tight_layout()
    plt.savefig("figures/worst_features.pdf")
    plt.close()

    plot_cost(df, split, filename="figures/costs.pdf")
    _df = df[df.cost < split].reset_index(drop=True)

    # Disabled: corner plots split on whether
    # "Step_ReboundBurst_burst.soma.v.time_to_last_spike" is below/above 8000.
    # plot_corner(
    #     _df[_df["features"]["Step_ReboundBurst_burst.soma.v.time_to_last_spike"] < 8000],
    #     feature=None,
    #     filename="figures/corner_non_runaway.pdf",
    # )
    # plot_corner(
    #     _df[_df["features"]["Step_ReboundBurst_burst.soma.v.time_to_last_spike"] > 8000],
    #     feature=None,
    #     filename="figures/corner_runaway.pdf",
    # )

    plot_corner(
        _df,
        feature=None,
        filename="figures/corner.pdf",
    )
    plt.close("all")

    plot_corner(
        _df,
        feature="cost",
        filename="figures/corner_cost.pdf",
    )
    plt.close("all")

    Path("figures/corners_features").mkdir(exist_ok=True)
    Path("figures/corners_scores").mkdir(exist_ok=True)

    # (column group, output filename template) for the per-feature corner plots.
    corner_targets = (
        ("features", "figures/corners_features/corner_features_{}.pdf"),
        ("scores", "figures/corners_scores/corner_score_{}.pdf"),
    )
    for feature in _df["scores"].columns:
        print("corner plot of ", feature)
        for group, template in corner_targets:
            try:
                plot_corner(
                    _df.reset_index(drop=True),
                    feature=(group, feature),
                    filename=template.format(feature),
                )
            except Exception as exc:  # best-effort: one failing plot must not stop the rest
                print(f"  skipped corner plot of {group} {feature!r}: {exc!r}")
            finally:
                # Close every figure, whether or not plotting succeeded, so
                # figures do not accumulate across the loop.
                plt.close("all")
